{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp data.block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.core import *\n",
    "from fastai.data.load import *\n",
    "from fastai.data.external import *\n",
    "from fastai.data.transforms import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data block\n",
    "\n",
    "> High level API to quickly get your data in a `DataLoaders`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TransformBlock -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TransformBlock():\n",
    "    \"A basic wrapper that links defaults transforms for the data block API\"\n",
    "    def __init__(self, type_tfms=None, item_tfms=None, batch_tfms=None, dl_type=None, dls_kwargs=None):\n",
    "        self.type_tfms  =            L(type_tfms)\n",
    "        self.item_tfms  = ToTensor + L(item_tfms)\n",
    "        self.batch_tfms =            L(batch_tfms)\n",
    "        self.dl_type,self.dls_kwargs = dl_type,({} if dls_kwargs is None else dls_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def CategoryBlock(vocab=None, sort=True, add_na=False):\n",
    "    \"`TransformBlock` for single-label categorical targets\"\n",
    "    return TransformBlock(type_tfms=Categorize(vocab=vocab, sort=sort, add_na=add_na))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def MultiCategoryBlock(encoded=False, vocab=None, add_na=False):\n",
    "    \"`TransformBlock` for multi-label categorical targets\"\n",
    "    tfm = EncodedMultiCategorize(vocab=vocab) if encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]\n",
    "    return TransformBlock(type_tfms=tfm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def RegressionBlock(n_out=None):\n",
    "    \"`TransformBlock` for float targets\"\n",
    "    return TransformBlock(type_tfms=RegressionSetup(c=n_out))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from inspect import isfunction,ismethod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _merge_grouper(o):\n",
    "    if isinstance(o, LambdaType): return id(o)\n",
    "    elif isinstance(o, type): return o\n",
    "    elif (isfunction(o) or ismethod(o)): return o.__qualname__\n",
    "    return o.__class__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _merge_tfms(*tfms):\n",
    "    \"Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating\"\n",
    "    g = groupby(concat(*tfms), _merge_grouper)\n",
    "    return L(v[-1] for k,v in g.items()).map(instantiate)\n",
    "\n",
    "def _zip(x): return L(x).zip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For example, so not exported\n",
    "from fastai.vision.core import *\n",
    "from fastai.vision.data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "tfms = _merge_tfms([Categorize, MultiCategorize, Categorize(['dog', 'cat'])], Categorize(['a', 'b']))\n",
    "#If there are several instantiated versions, the last one is kept.\n",
    "test_eq(len(tfms), 2)\n",
    "test_eq(tfms[1].__class__, MultiCategorize)\n",
    "test_eq(tfms[0].__class__, Categorize)\n",
    "test_eq(tfms[0].vocab, ['a', 'b'])\n",
    "\n",
    "tfms = _merge_tfms([PILImage.create, PILImage.show])\n",
    "#Check methods are properly separated\n",
    "test_eq(len(tfms), 2)\n",
    "tfms = _merge_tfms([show_image, set_trace])\n",
    "#Check functions are properly separated\n",
    "test_eq(len(tfms), 2)\n",
    "\n",
    "_f = lambda x: 0\n",
    "test_eq(len(_merge_tfms([_f,lambda x: 1])), 2)\n",
    "test_eq(len(_merge_tfms([_f,_f])), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@docs\n",
    "@funcs_kwargs\n",
    "class DataBlock():\n",
    "    \"Generic container to quickly build `Datasets` and `DataLoaders`\"\n",
    "    get_x=get_items=splitter=get_y = None\n",
    "    blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL\n",
    "    _methods = 'get_items splitter get_y get_x'.split()\n",
    "    _msg = \"If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`.\"\n",
    "    def __init__(self, blocks=None, dl_type=None, getters=None, n_inp=None, item_tfms=None, batch_tfms=None, **kwargs):\n",
    "        blocks = L(self.blocks if blocks is None else blocks)\n",
    "        blocks = L(b() if callable(b) else b for b in blocks)\n",
    "        self.type_tfms = blocks.attrgot('type_tfms', L())\n",
    "        self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))\n",
    "        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))\n",
    "        for b in blocks:\n",
    "            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type\n",
    "        if dl_type is not None: self.dl_type = dl_type\n",
    "        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)\n",
    "        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))\n",
    "\n",
    "        self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))\n",
    "        self.getters = ifnone(getters, [noop]*len(self.type_tfms))\n",
    "        if self.get_x:\n",
    "            if len(L(self.get_x)) != self.n_inp:\n",
    "                raise ValueError(f'get_x contains {len(L(self.get_x))} functions, but must contain {self.n_inp} (one for each input)\\n{self._msg}')\n",
    "            self.getters[:self.n_inp] = L(self.get_x)\n",
    "        if self.get_y:\n",
    "            n_targs = len(self.getters) - self.n_inp\n",
    "            if len(L(self.get_y)) != n_targs:\n",
    "                raise ValueError(f'get_y contains {len(L(self.get_y))} functions, but must contain {n_targs} (one for each target)\\n{self._msg}')\n",
    "            self.getters[self.n_inp:] = L(self.get_y)\n",
    "\n",
    "        if kwargs: raise TypeError(f'invalid keyword arguments: {\", \".join(kwargs.keys())}')\n",
    "        self.new(item_tfms, batch_tfms)\n",
    "\n",
    "    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(\n",
    "        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)\n",
    "\n",
    "    def new(self, item_tfms=None, batch_tfms=None):\n",
    "        self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)\n",
    "        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)\n",
    "        return self\n",
    "\n",
    "    @classmethod\n",
    "    def from_columns(cls, blocks=None, getters=None, get_items=None, **kwargs):\n",
    "        if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))\n",
    "        get_items = _zip if get_items is None else compose(get_items, _zip)\n",
    "        return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)\n",
    "\n",
    "    def datasets(self, source, verbose=False):\n",
    "        self.source = source                     ; pv(f\"Collecting items from {source}\", verbose)\n",
    "        items = (self.get_items or noop)(source) ; pv(f\"Found {len(items)} items\", verbose)\n",
    "        splits = (self.splitter or RandomSplitter())(items)\n",
    "        pv(f\"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}\", verbose)\n",
    "        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)\n",
    "\n",
    "    def dataloaders(self, source, path='.', verbose=False, **kwargs):\n",
    "        dsets = self.datasets(source, verbose=verbose)\n",
    "        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}\n",
    "        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)\n",
    "\n",
    "    _docs = dict(new=\"Create a new `DataBlock` with other `item_tfms` and `batch_tfms`\",\n",
    "                 datasets=\"Create a `Datasets` object from `source`\",\n",
    "                 dataloaders=\"Create a `DataLoaders` object from `source`\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build a `DataBlock` you need to give the library four things: the types of your input/labels, and at least two functions: `get_items` and `splitter`. You may also need to include `get_x` and `get_y` or a more generic list of `getters` that are applied to the results of `get_items`.\n",
    "\n",
    "Once those are provided, you automatically get a `Datasets` or a `DataLoaders`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataBlock.datasets\" class=\"doc_header\"><code>DataBlock.datasets</code><a href=\"__main__.py#L51\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>DataBlock.datasets</code>(**`source`**, **`verbose`**=*`False`*)\n",
       "\n",
       "Create a [`Datasets`](/data.core#Datasets) object from `source`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataBlock.datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataBlock.dataloaders\" class=\"doc_header\"><code>DataBlock.dataloaders</code><a href=\"__main__.py#L58\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>DataBlock.dataloaders</code>(**`source`**, **`path`**=*`'.'`*, **`verbose`**=*`False`*, **`bs`**=*`64`*, **`shuffle`**=*`False`*, **`num_workers`**=*`None`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`drop_last`**=*`False`*, **`indexed`**=*`None`*, **`n`**=*`None`*, **`device`**=*`None`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n",
       "\n",
       "Create a [`DataLoaders`](/data.core#DataLoaders) object from `source`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide_input\n",
    "dblock = DataBlock()\n",
    "show_doc(dblock.dataloaders, name=\"DataBlock.dataloaders\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can create a `DataBlock` by passing functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock(blocks = (ImageBlock(cls=PILImageBW),CategoryBlock),\n",
    "                  get_items = get_image_files,\n",
    "                  splitter = GrandparentSplitter(),\n",
    "                  get_y = parent_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each type comes with default transforms that will be applied\n",
    "- at the base level to create items in a tuple (usually input,target) from the base elements (like filenames)\n",
    "- at the item level of the datasets\n",
    "- at the batch level\n",
    "\n",
    "They are called respectively type transforms, item transforms, batch transforms. In the case of MNIST, the type transforms are the method to create a `PILImageBW` (for the input) and the `Categorize` transform (for the target), the item transform is `ToTensor` and the batch transforms are `Cuda` and `IntToFloatTensor`. You can add any other transforms by passing them in `DataBlock.datasets` or `DataBlock.dataloaders`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(mnist.type_tfms[0], [PILImageBW.create])\n",
    "test_eq(mnist.type_tfms[1].map(type), [Categorize])\n",
    "test_eq(mnist.default_item_tfms.map(type), [ToTensor])\n",
    "test_eq(mnist.default_batch_tfms.map(type), [IntToFloatTensor])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAACLCAYAAABBVeZmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAE3klEQVR4nO2dTyh0axzHf4/srpAkpRQJCTvkioQsbBSykO4srGSDsvRnIVvq9a+s2Chyy07dEq5YKSlFKVnpXlLcu7DAuav3NL/nvWbMOzPnnJnv91PqfJs508On3/l5znPOGeM4jhAMMvweAPEOygaCsoGgbCAoGwjKBoKygYCSbYz51/p5N8Z883tcXpHp9wC8xHGcrO/bxphfROQvEdn2b0TeAlXZFn0i8reI/On3QLwCWXZIRDYcoPPFBuh3dTHGFIvIrYiUOY5z6/d4vAK1sn8TkWMk0SLYstf9HoTXwB3GjTG/isgfIlLoOM4/fo/HSxArOyQiv6OJFgGsbGQQKxsWygaCsoGgbCAoG4hoq178Vz31MJ+9wMoGgrKBoGwgKBsIygaCsoGgbCAoGwjKBoKygaBsICgbCMoGgrKBoGwgKBsIygaCsoGgbCAoGwjKBoKygaBsICgbCMoGgrKBSNmH3r29val8fn6ucn19vcoFBQXu9sTEhHrt+vpa5bW1NZXb2tpU7urqim2wFrW1tSq3t7e72xkZyas/VjYQlA0EZQMR7QE6gbll9/n5WeWRkRGVNzc3vRxOQnl4eHC38/Ly4v043rJLKBsKygYisPPs19dXlZubm1W+vLyM6fNmZmbc7c7OzojvXVxcVPn09FTl29v4nm+bn5+vcmamNxpY2UBQNhCUDURgerbdo4eHh1WO1qON0dPLoaEhlcfGxtztrKwsiURDQ4PKLy8vKu/v76vc29sb8fNsenp6VM7Ozo5p/5+FlQ0EZQNB2UAE5tz43d2dyqWlpTHt39TUpPLR0VHcY/rOwcGByoODgyrf399H3L+oqEjlw8NDlUtKSn5+cD/Cc+OEsqGgbCACM8/Ozc1VORQKqby+rr+GK3zeLCIyOzubnIGJyO7ursrRerTN1taWygnu0V+GlQ0EZQNB2UAEZp5t8/HxobJ9nbh9fXU8a8L232B6elrlubm5iO+3mZycVHlqakrlZF4bLpxnExHKhiKwh/Fk8v7+rrJ9arWjoyPi/vZyaktLi8r2EqjH8DBOKBsKygYCsmff3NyoXFFREdP+ZWVlKtu3/PoMezahbCgoG4jALHEmm8fHR3d7ZWUlpn2Li4tVPj4+TsiYvIaVDQRlA0HZQKTtPNteEh0YGHC3d3Z2Iu5rXzZk9+jCwsI4R5dUOM8mlA0FZQORNvNs+zKm8fFxlaP16XD29vZUDniP/jKsbCAoGwjKBiJtevb29rbKS0tLn77XPte9sLCgcqy3C6cKrGwgKBsIygYiZXu23aPDz31Ho66uTuXu7u6EjCnosLKBoGwgKBuIlFnPvrq6UrmxsVFl+5GTNqOjo+62fUut/YiPFIfr2YSyoaBsIALbs+1HTre2tqr89PQUcf+amhqVwx9JmWY92oY9m1A2FJQNRGB6dvhXFIqIVFdXqxx+r9b/UVlZqfLJyYnKOTk5cYwupWDPJpQNRWCWOJeXl1WOdtiuqqpSeXV1VWWgw/aXYWUDQdlAUDYQvvXss7Mzlefn52Pav6+vT2X723/Ij7CygaBsICgbCN96dnl5ucr2oy0uLi5U3tjYULm/vz85A0tjWNlAUDYQlA1EYJY4ScLgEiehbCgoG4ho8+xPj/8k9WBlA0HZQFA2EJQNBGUDQdlA/AdDRSN/yb0OvgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(dsets.vocab, ['3', '7'])\n",
    "x,y = dsets.train[0]\n",
    "test_eq(x.size,(28,28))\n",
    "show_at(dsets.train, 0, cmap='Greys', figsize=(2,2));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fail(lambda: DataBlock(wrong_kwarg=42, wrong_kwarg2='foo'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can pass any number of blocks to `DataBlock`, we can then define what are the input and target blocks by changing `n_inp`. For example, defining `n_inp=2` will consider the first two blocks passed as inputs and the others as targets. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                   get_y=parent_label)\n",
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(mnist.n_inp, 2)\n",
    "test_eq(len(dsets.train[0]), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fail(lambda: DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                  get_y=[parent_label, noop],\n",
    "                  n_inp=2), msg='get_y contains 2 functions, but must contain 1 (one for each output)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                  n_inp=1,\n",
    "                  get_y=[noop, Pipeline([noop, parent_label])])\n",
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(len(dsets.train[0]), 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debugging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _short_repr(x):\n",
    "    if isinstance(x, tuple): return f'({\", \".join([_short_repr(y) for y in x])})'\n",
    "    if isinstance(x, list): return f'[{\", \".join([_short_repr(y) for y in x])}]'\n",
    "    if not isinstance(x, Tensor): return str(x)\n",
    "    if x.numel() <= 20 and x.ndim <=1: return str(x)\n",
    "    return f'{x.__class__.__name__} of size {\"x\".join([str(d) for d in x.shape])}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "test_eq(_short_repr(TensorImage(torch.randn(40,56))), 'TensorImage of size 40x56')\n",
    "test_eq(_short_repr(TensorCategory([1,2,3])), 'TensorCategory([1, 2, 3])')\n",
    "test_eq(_short_repr((TensorImage(torch.randn(40,56)), TensorImage(torch.randn(32,20)))),\n",
    "        '(TensorImage of size 40x56, TensorImage of size 32x20)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _apply_pipeline(p, x):\n",
    "    print(f\"  {p}\\n    starting from\\n      {_short_repr(x)}\")\n",
    "    for f in p.fs:\n",
    "        name = f.name\n",
    "        try:\n",
    "            x = f(x)\n",
    "            if name != \"noop\": print(f\"    applying {name} gives\\n      {_short_repr(x)}\")\n",
    "        except Exception as e:\n",
    "            print(f\"    applying {name} failed.\")\n",
    "            raise e\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.load import _collate_types\n",
    "\n",
    "def _find_fail_collate(s):\n",
    "    s = L(*s)\n",
    "    for x in s[0]:\n",
    "        if not isinstance(x, _collate_types): return f\"{type(x).__name__} is not collatable\"\n",
    "    for i in range_of(s[0]):\n",
    "        try: _ = default_collate(s.itemgot(i))\n",
    "        except:\n",
    "            shapes = [getattr(o[i], 'shape', None) for o in s]\n",
    "            return f\"Could not collate the {i}-th members of your tuples because got the following shapes\\n{','.join([str(s) for s in shapes])}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def summary(self: DataBlock, source, bs=4, show_batch=False, **kwargs):\n",
    "    \"Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`.\"\n",
    "    print(f\"Setting-up type transforms pipelines\")\n",
    "    dsets = self.datasets(source, verbose=True)\n",
    "    print(\"\\nBuilding one sample\")\n",
    "    for tl in dsets.train.tls:\n",
    "        _apply_pipeline(tl.tfms, get_first(dsets.train.items))\n",
    "    print(f\"\\nFinal sample: {dsets.train[0]}\\n\\n\")\n",
    "\n",
    "    dls = self.dataloaders(source, bs=bs, verbose=True)\n",
    "    print(\"\\nBuilding one batch\")\n",
    "    if len([f for f in dls.train.after_item.fs if f.name != 'noop'])!=0:\n",
    "        print(\"Applying item_tfms to the first sample:\")\n",
    "        s = [_apply_pipeline(dls.train.after_item, dsets.train[0])]\n",
    "        print(f\"\\nAdding the next {bs-1} samples\")\n",
    "        s += [dls.train.after_item(dsets.train[i]) for i in range(1, bs)]\n",
    "    else:\n",
    "        print(\"No item_tfms to apply\")\n",
    "        s = [dls.train.after_item(dsets.train[i]) for i in range(bs)]\n",
    "\n",
    "    if len([f for f in dls.train.before_batch.fs if f.name != 'noop'])!=0:\n",
    "        print(\"\\nApplying before_batch to the list of samples\")\n",
    "        s = _apply_pipeline(dls.train.before_batch, s)\n",
    "    else: print(\"\\nNo before_batch transform to apply\")\n",
    "\n",
    "    print(\"\\nCollating items in a batch\")\n",
    "    try:\n",
    "        b = dls.train.create_batch(s)\n",
    "        b = retain_types(b, s[0] if is_listy(s) else s)\n",
    "    except Exception as e:\n",
    "        print(\"Error! It's not possible to collate your items in a batch\")\n",
    "        why = _find_fail_collate(s)\n",
    "        print(\"Make sure all parts of your samples are tensors of the same size\" if why is None else why)\n",
    "        raise e\n",
    "\n",
    "    if len([f for f in dls.train.after_batch.fs if f.name != 'noop'])!=0:\n",
    "        print(\"\\nApplying batch_tfms to the batch built\")\n",
    "        b = to_device(b, dls.device)\n",
    "        b = _apply_pipeline(dls.train.after_batch, b)\n",
    "    else: print(\"\\nNo batch_tfms to apply\")\n",
    "\n",
    "    if show_batch: dls.show_batch(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"DataBlock.summary\" class=\"doc_header\"><code>DataBlock.summary</code><a href=\"__main__.py#L2\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>DataBlock.summary</code>(**`source`**, **`bs`**=*`4`*, **`show_batch`**=*`False`*, **\\*\\*`kwargs`**)\n",
       "\n",
       "Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(DataBlock.summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides stepping through the transformation, `summary()`  provides a shortcut `dls.show_batch(...)`, to see the data.  E.g.\n",
    "\n",
    "```\n",
    "pets.summary(path/\"images\", bs=8, show_batch=True, unique=True,...)\n",
    "```\n",
    "\n",
    "is a shortcut to:\n",
    "```\n",
    "pets.summary(path/\"images\", bs=8)\n",
    "dls = pets.dataloaders(path/\"images\", bs=8)\n",
    "dls.show_batch(unique=True,...)  # See different tfms effect on the same image.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
